# -*- coding: utf-8 -*-
"""
Created on Fri Oct  1 18:38:08 2021
因 tf2 运行速度慢
将 神经网络 从 tf2 中解放出来
直接用 矩阵 来作为 神经网络
@author: Waikikilick
"""

import tensorflow as tf
from tensorflow.keras import layers, Sequential
import numpy as np
import random

import warnings
from tensorflow import keras
from scipy.linalg import expm
warnings.filterwarnings("ignore", category=DeprecationWarning)
from time import *

class env(object):
    """Single-qubit state-preparation environment.

    The qubit evolves under H = J * sigma_z + sigma_x, where the control J is
    chosen from `action_space` and held constant for one time step `dt`.
    The goal is to drive an arbitrary initial state to the target state |0>.
    Because the agent's network is real-valued, a state |psi> = (a, b)^T is
    exposed as the 4-vector [Re(a), Re(b), Im(a), Im(b)].
    """

    def __init__(self,
                 action_space=None,      # allowed control values; defaults to [0, 1]
                 dt=0.1,                 # duration of one control pulse
                 fid_threshold=0.9999,   # fidelity at which an episode terminates
                 max_steps=20,           # episode length cap (steps)
                 ):
        # Use a None sentinel instead of a shared mutable default list.
        self.action_space = [0, 1] if action_space is None else action_space
        self.n_actions = len(self.action_space)
        self.n_features = 4  # length of the real-valued state representation
        self.target_psi = np.mat([[1], [0]], dtype=complex)  # target state |0>
        self.s_x = np.mat([[0, 1], [1, 0]], dtype=complex)   # Pauli X
        self.s_z = np.mat([[1, 0], [0, -1]], dtype=complex)  # Pauli Z
        self.dt = dt
        self.fid_threshold = fid_threshold
        self.max_steps = max_steps
        self.training_set, self.validation_set, self.testing_set = self.psi_set()

    def psi_set(self):
        """Sample initial states on the Bloch sphere and split them.

        Returns:
            (training_set, validation_set, testing_set): lists (32 / 32 / 64)
            of 2x1 `np.mat` column vectors
            |psi> = cos(theta/2)|0> + e^{i varphi} sin(theta/2)|1>.
        """
        theta_num = 6    # number of theta values, excluding the two poles
        varphi_num = 21  # number of varphi values per latitude circle
        # total points: theta_num * varphi_num + 2 poles = 6 * 21 + 2 = 128

        # theta strictly inside (0, pi): endpoint=False drops pi, delete([0]) drops 0
        theta = np.delete(np.linspace(0, np.pi, theta_num + 1, endpoint=False), [0])
        varphi = np.linspace(0, np.pi * 2, varphi_num, endpoint=False)

        psi_set = []
        for ii in range(theta_num):
            for jj in range(varphi_num):
                psi_set.append(np.mat([[np.cos(theta[ii] / 2)],
                                       [np.sin(theta[ii] / 2) * (np.cos(varphi[jj]) + np.sin(varphi[jj]) * (0 + 1j))]]))
        psi_set.append(np.mat([[1], [0]], dtype=complex))  # north pole |0>
        psi_set.append(np.mat([[0], [1]], dtype=complex))  # south pole |1>
        random.shuffle(psi_set)  # shuffle the point set

        training_set = psi_set[0:32]
        validation_set = psi_set[32:64]
        testing_set = psi_set[64:128]

        return training_set, validation_set, testing_set

    def reset(self, init_psi):
        """Start a new episode from `init_psi` (2x1 complex column vector).

        Returns the real 4-vector [Re(a), Re(b), Im(a), Im(b)].
        """
        init_state = np.array([init_psi[0, 0].real, init_psi[1, 0].real,
                               init_psi[0, 0].imag, init_psi[1, 0].imag])
        return init_state

    def step(self, state, action, nstep):
        """Apply one control pulse and return (next_state, done, fidelity).

        Args:
            state: real 4-vector produced by `reset`/`step`.
            action: control value J for this step (cast to float).
            nstep: number of steps already taken this episode.
        """
        # Rebuild the complex 2x1 state vector from its real representation.
        half = int(len(state) / 2)
        psi = np.mat(np.array([state[0:half] + state[half:] * 1j]).T)

        H = float(action) * self.s_z + 1 * self.s_x
        U = expm(-1j * H * self.dt)  # exact propagator for one step
        psi = U * psi  # next state

        # Fidelity with the target: |<psi|target>|^2.
        fid = (np.abs(psi.H * self.target_psi) ** 2).item(0).real

        done = (fid >= self.fid_threshold or nstep >= self.max_steps)

        # Back to the real-valued form: complex input cannot be fed to the
        # network (gradient-based training does not handle it).
        psi = np.array(psi)
        psi_T = psi.T
        state = np.array(psi_T.real.tolist()[0] + psi_T.imag.tolist()[0])

        return state, done, fid
    
def relu(vector):
    """Element-wise ReLU: max(0, x).

    Single vectorized call instead of the original row-by-row loop; returns
    an object of the same type/shape as `vector` (np.mat in / np.mat out).
    """
    return np.maximum(0, vector)

def agent(x):
    """Greedy policy: forward pass through the 3-layer MLP and pick the
    best action.

    Relies on the module-level `net`, a list alternating kernel / bias
    tensors [W1, b1, W2, b2, W3, b3] taken from the trained Keras model.
    Returns the argmax index over the output activations.
    """
    activation = np.mat(x)
    # Each layer is affine followed by ReLU (including the last one,
    # exactly as the original forward pass did).
    for layer in range(0, 6, 2):
        activation = relu(activation * net[layer].numpy() + net[layer + 1].numpy())
    return np.argmax(activation)

def testing_time():
    """Run one greedy episode per state in the testing set and time it.

    Returns a list of tuples
        (estimated_time_to_peak, peak_fidelity, peak_step_index)
    where the wall-clock episode time is pro-rated down to the step at which
    the peak fidelity was first reached.

    NOTE(review): relies on the module-level globals `env`, `agent` and
    `testing_set` being set up in the __main__ block.
    """
    print('\n时间测试中, 请稍等...')

    data = []

    for test_init_psi in testing_set:

        fid_list_tem = []
        observation = env.reset(test_init_psi)
        nstep = 0
        start_time = time()
        while True:
            action = agent(observation)
            observation, done, fid = env.step(observation, action, nstep)
            nstep += 1
            fid_list_tem.append(fid)
            if done:
                end_time = time()
                break

        # Compute the maximum once (the original recomputed it inside
        # .index(...)) and drop the unused fid_max accumulator.
        test_fid = max(fid_list_tem)
        max_index = fid_list_tem.index(test_fid)
        # Fraction of the episode's wall-clock time spent up to the peak step.
        test_time = (end_time - start_time) * max_index / len(fid_list_tem)
        data.append((test_time, test_fid, max_index))

    return data

if __name__ == "__main__":

    dt = np.pi/10
    # Load the trained DQN and pull out its weight tensors so the forward
    # pass can be done with plain matrix algebra (faster than calling tf2).
    network = keras.models.load_model('dqn_tf2_single_qubit_to_0_saved_model')
    net = network.trainable_variables
    env = env(action_space = list(range(4)),   # 4 allowed actions: 0 .. 3
               dt = dt)

    testing_set = env.testing_set

    # Evaluate the whole testing set, then sort records by elapsed time.
    data = testing_time()
    data.sort()

    # Named total_* so we don't shadow the time() function brought in by
    # `from time import *`.
    total_time = sum(record[0] for record in data)
    total_fid = sum(record[1] for record in data)
    total_max_index = sum(record[2] for record in data)
    time_list = [record[0] for record in data]
    fid_list = [record[1] for record in data]

    print('the mean time is: ', total_time/len(testing_set))
    print('the mean fid is: ', total_fid/len(testing_set))
    print('the mean max_index is: ', total_max_index/len(testing_set))

    for record in data:
        print(record)

    # print(np.array(time_list))
    # print(np.array(fid_list))
    # the mean time is:  0.01196698737995965
    # time: 0.         0.00095387 0.00190456 0.00190476 0.00238845 0.00306719 0.00476207 0.00477183 0.00533958 0.00640242 0.00652217 0.00665514 0.00666579 0.0076081  0.00762422 0.00763303 0.0085703  0.00857142 0.00857397 0.00917196 0.00952051 0.00953958 0.00969489 0.00998974 0.01001167 0.01047069 0.01047418 0.01047568 0.01049604 0.01077931 0.01118834 0.01120732 0.01142924 0.0114321  0.01238537 0.0128291 0.01332935 0.01333189 0.01335192 0.01349565 0.01428536 0.01428706 0.01482609 0.01522664 0.01523808 0.01619007 0.016192   0.0171408 0.0171504  0.01806459 0.01806977 0.01809522 0.0181004  0.01862935 0.018677   0.01898766 0.01901558 0.01904556 0.01904964 0.01904964 0.01915925 0.02033064 0.02189952 0.02465348]
    
    # the mean fid is:  0.9968424989421337
    # fid: 0.99992706 0.99412025 0.99669081 0.99937223 0.99810897 0.99844001 0.99953221 0.99987183 0.99793122 0.99695026 0.99262794 0.99407401 0.99385914 0.99740712 0.99964585 0.99518491 0.998163   0.997538 0.9874814  0.99560174 0.9984762  0.99652099 0.99527229 0.99586106 0.99360692 0.99711408 0.99679618 0.99655595 0.99813032 0.99739846 0.99683259 0.99541446 0.99778098 0.99891509 0.99755767 0.99047974 0.9975986  0.99816336 0.99721183 0.99949408 0.99597828 0.99629809 0.99815798 0.99297227 0.99003232 0.99765897 0.99729656 0.99855401 0.99751055 0.99925873 0.99662283 0.99723935 0.99812714 0.99939426 0.99526946 0.99956646 0.99885223 0.99702491 0.99596365 0.99801181 0.99891761 0.99925065 0.99558892 0.99666409
    
    # (0.0, 0.9999270558227643)
    # (0.0009538673219226656, 0.9941202501599435)
    # (0.0019045557294573104, 0.9966908112314136)
    # (0.0019047600882393973, 0.9993722344613314)
    # (0.00238845461890811, 0.9981089656054247)
    # (0.0030671869005475727, 0.9984400062918157)
    # (0.004762070519583566, 0.9995322089914118)
    # (0.004771834328061058, 0.9998718319486674)
    # (0.0053395770844959075, 0.9979312184007922)
    # (0.0064024244035993305, 0.9969502610376323)
    # (0.006522167296636672, 0.9926279402780416)
    # (0.006655136744181315, 0.9940740148379621)
    # (0.006665786107381185, 0.9938591352090292)
    # (0.007608095804850261, 0.9974071213655467)
    # (0.0076242174421037945, 0.9996458470759457)
    # (0.007633027576264881, 0.9951849133898415)
    # (0.00857029642377581, 0.9981629978756337)
    # (0.008571420397077287, 0.9975380002985932)
    # (0.008573974881853377, 0.9874813991752841)
    # (0.00917196273803711, 0.995601744109785)
    # (0.009520507994152251, 0.998476197954729)
    # (0.009539581480480376, 0.9965209922936167)
    # (0.009694894154866537, 0.9952722878032612)
    # (0.009989738464355469, 0.9958610600314679)
    # (0.010011672973632812, 0.9936069166603954)
    # (0.010470685504731677, 0.9971140806563806)
    # (0.010474182310558501, 0.9967961777931901)
    # (0.010475680941627138, 0.9965559500448835)
    # (0.010496037346976144, 0.998130322809622)
    # (0.010779312678745814, 0.9973984556862291)
    # (0.011188336781093053, 0.9968325933582057)
    # (0.011207319441295806, 0.9954144565671021)
    # (0.011429241725376673, 0.9977809764517119)
    # (0.011432102748325892, 0.9989150867311203)
    # (0.012385368347167969, 0.9975576668825328)
    # (0.01282909938267299, 0.9904797399887192)
    # (0.013329346974690756, 0.9975986010953874)
    # (0.013331890106201172, 0.9981633561543155)
    # (0.013351917266845703, 0.9972118284450591)
    # (0.013495649610246931, 0.9994940777834056)
    # (0.014285360063825334, 0.9959782807768726)
    # (0.014287063053676061, 0.9962980944854838)
    # (0.014826093401227678, 0.9981579838601582)
    # (0.015226636614118303, 0.9929722736658954)
    # (0.015238080705915178, 0.990032322197728)
    # (0.016190074739002046, 0.9976589651790656)
    # (0.0161920047941662, 0.9972965617372226)
    # (0.017140797206333706, 0.9985540058613029)
    # (0.017150402069091797, 0.9975105535828962)
    # (0.018064589727492558, 0.9992587280103296)
    # (0.018069766816638765, 0.9966228343376382)
    # (0.018095220838274275, 0.9972393472546977)
    # (0.01810039792742048, 0.9981271448719122)
    # (0.018629346575055803, 0.9993942575169792)
    # (0.0186769962310791, 0.9952694587185998)
    # (0.018987655639648438, 0.9995664612046153)
    # (0.019015584673200334, 0.9988522274802045)
    # (0.019045557294573103, 0.9970249102974611)
    # (0.019049644470214844, 0.9959636548091747)
    # (0.019049644470214844, 0.9980118098273612)
    # (0.019159248897007534, 0.9989176140970424)
    # (0.020330644789196196, 0.9992506532163739)
    # (0.021899518512544177, 0.9955889189757052)
    # (0.024653480166480654, 0.996664087573618)